#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import os

from gen_elemwise_utils import DTYPES


def main():
    """Generate one elemwise special-kernel source file per entry of ``DTYPES``.

    Command-line arguments:
        --type {cuda,hip}: selects the extension of the generated files
            (``cu`` for ``cuda``, ``cpp.hip`` for ``hip``). Defaults to
            ``cuda``.
        output: directory that receives the generated files; it is created
            (including missing parents) when it does not exist.

    For every key ``k`` of ``DTYPES`` a file ``special_<k>.<ext>`` is written.
    The first element of ``DTYPES[k]`` is used as the ``::megdnn::dtype::``
    class name passed to ``INST``. Files for ``dt_float16`` and
    ``dt_bfloat16`` are wrapped in ``#if !MEGDNN_DISABLE_FLOAT16``.

    After all files are written, the timestamps of the output directory are
    refreshed with ``os.utime``.
    """
    parser = argparse.ArgumentParser(
        description="generate elemwise impl files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["cuda", "hip"],
        default="cuda",
        help="generate cuda/hip elemwise special kernel file",
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of isdir() + makedirs() when
    # several generator invocations target the same directory concurrently.
    os.makedirs(args.output, exist_ok=True)

    # argparse restricts --type to exactly these keys, so the lookup cannot
    # fail for parsed arguments.
    cpp_ext = {"cuda": "cu", "hip": "cpp.hip"}[args.type]

    for dtype in DTYPES:
        fname = os.path.join(
            args.output, "special_{}.{}".format(dtype, cpp_ext)
        )
        fp16_guarded = dtype in ("dt_float16", "dt_bfloat16")

        lines = ["// generated by gen_elemwise_special_kern_impls.py"]
        if fp16_guarded:
            lines.append("#if !MEGDNN_DISABLE_FLOAT16")
        lines += [
            '#include "../special_kerns.inl"',
            "INST(::megdnn::dtype::{})".format(DTYPES[dtype][0]),
            "#undef INST",
            # NOTE(review): the two closing braces presumably terminate
            # scopes opened inside special_kerns.inl -- confirm against it.
            "}",
            "}",
        ]
        if fp16_guarded:
            lines.append("#endif")

        # Explicit encoding keeps the output independent of the locale.
        with open(fname, "w", encoding="utf-8") as fout:
            fout.write("\n".join(lines) + "\n")

        print("generated {}".format(fname))

    # Set the directory's access/modification times to "now".
    # NOTE(review): presumably used by the build system as a freshness
    # stamp for the generated sources -- confirm.
    os.utime(args.output)


# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
